# db.py
import os
from contextlib import closing

import mysql.connector
from dotenv import load_dotenv

# Populate os.environ from a .env file (python-dotenv) at import time so the
# DB_* variables read by get_connection() are available.
load_dotenv()


def get_connection():
    """Open and return a new MySQL connection.

    Settings come from the DB_HOST, DB_USER, DB_PASSWORD and DB_NAME
    environment variables; an unset variable is passed through as None.
    The caller owns the returned connection and is responsible for closing it.
    """
    env = os.environ
    params = {
        "host": env.get("DB_HOST"),
        "user": env.get("DB_USER"),
        "password": env.get("DB_PASSWORD"),
        "database": env.get("DB_NAME"),
    }
    return mysql.connector.connect(**params)


# ----- Session functions -----
def start_session(name):
    """Insert a new OPEN session and return a summary of it.

    Args:
        name: Session name stored in ``sessions.name``.

    Returns:
        A dict with keys "id" (the new row's ``lastrowid``), "name",
        "start_time" and "status" ("OPEN").
    """
    # closing() guarantees cursor and connection are released even if the
    # INSERT or commit raises (exit order: cursor first, then connection).
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "INSERT INTO sessions (name, start_time, status) VALUES (%s, NOW(), 'OPEN')",
            (name,),
        )
        db.commit()
        sid = cur.lastrowid
    # start_time is assigned server-side by NOW() and is not read back, so the
    # returned dict carries None for it.
    return {"id": sid, "name": name, "start_time": None, "status": "OPEN"}


def end_session():
    """Close every session currently marked 'OPEN'.

    Sets ``end_time`` to NOW() and ``status`` to 'CLOSE' on all OPEN rows.

    Returns:
        True (unconditionally, once the UPDATE has been committed).
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        # NOTE(review): the literal 'CLOSE' must be a value sessions.status
        # accepts. If the column is an ENUM lacking it, MySQL stores '' in
        # non-strict mode and rejects the UPDATE in strict mode; confirm schema.
        cur.execute(
            "UPDATE sessions SET end_time=NOW(), status='CLOSE' WHERE status='OPEN'"
        )
        db.commit()
    return True


def get_current_session():
    """Return the most recent OPEN session, or None if there is none.

    Returns:
        A dict of all ``sessions`` columns for the OPEN row with the highest
        id, or None when no session is OPEN.
    """
    with closing(get_connection()) as db, closing(db.cursor(dictionary=True)) as cur:
        cur.execute("SELECT * FROM sessions WHERE status='OPEN' ORDER BY id DESC LIMIT 1")
        session = cur.fetchone()
    return session


def add_verify_log(session_id, user_id, score, success, error=None):
    """Record one verification attempt in ``verify_logs``.

    Args:
        session_id: Session the attempt belongs to.
        user_id: User the attempt was made for.
        score: Match score, stored as given.
        success: Truthy for a successful match; stored as ``int(success)``.
        error: Optional error text; NULL when omitted.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "INSERT INTO verify_logs (session_id, user_id, score, verify_time, success, error) VALUES (%s, %s, %s, NOW(), %s, %s)",
            (session_id, user_id, score, int(success), error),
        )
        db.commit()


def get_session_users_logs(session_id):
    """Return every user joined with their latest verify log for a session.

    Args:
        session_id: Session whose verify_logs are considered.

    Returns:
        A list of dicts ordered by fullname, one per user. Keys:
        - id, student_code, fullname, year, department, photo: from ``users``.
        - verify_time, score: from the user's latest log in this session, or
          None when the user has no log there.
        - verified: 1 if that latest log has success = 1, otherwise 0.
    """
    # The latest log per user is chosen by MAX(id) rather than MAX(verify_time):
    # verify_time comes from NOW(), so two attempts in the same clock tick tie
    # on MAX(verify_time) and the join would emit duplicate rows for that user.
    # Assumes verify_logs.id is AUTO_INCREMENT (highest id == latest insert) —
    # TODO confirm against schema.
    with closing(get_connection()) as db, closing(db.cursor(dictionary=True)) as cur:
        cur.execute(
            """
            SELECT
                u.id,
                u.student_code,
                u.fullname,
                u.year,
                u.department,
                u.photo,
                vl.verify_time,
                vl.score,
                IF(vl.success = 1, 1, 0) AS verified
            FROM users u
            LEFT JOIN (
                SELECT l1.*
                FROM verify_logs l1
                INNER JOIN (
                    SELECT MAX(id) AS max_id
                    FROM verify_logs
                    WHERE session_id = %s
                    GROUP BY user_id
                ) l2
                ON l1.id = l2.max_id
            ) vl ON u.id = vl.user_id
            ORDER BY u.fullname
            """,
            (session_id,),
        )
        users = cur.fetchall()
    return users


def check_verified_in_session(session_id, user_id):
    """Return True if the user has at least one successful log in the session.

    Args:
        session_id: Session to check.
        user_id: User to check.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "SELECT COUNT(*) FROM verify_logs WHERE session_id=%s AND user_id=%s AND success=1",
            (session_id, user_id),
        )
        already = cur.fetchone()[0]
    return already > 0


def count_verified_in_session(session_id):
    """Return how many distinct users have a successful log in the session.

    Args:
        session_id: Session to count.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "SELECT COUNT(DISTINCT user_id) FROM verify_logs WHERE session_id=%s AND success=1",
            (session_id,),
        )
        cnt = cur.fetchone()[0]
    return cnt


def count_users():
    """Return the total number of rows in ``users``."""
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute("SELECT COUNT(*) FROM users")
        count = cur.fetchone()[0]
    return count


def insert_user(student_code, fullname, year, department, photo_bytes):
    """Insert a new user row and return its auto-generated id.

    Args:
        student_code, fullname, year, department: Stored as given.
        photo_bytes: Photo data stored in the ``photo`` column.

    Returns:
        The INSERT's ``cursor.lastrowid``.

    Raises:
        ValueError: On MySQL error 1062 (duplicate entry). The Thai message
            says the student code is already in use and asks for a new one.
        mysql.connector.Error: Any other database error from the INSERT or the
            commit; the transaction is rolled back before re-raising.
    """
    sql = "INSERT INTO users (student_code, fullname, year, department, photo) VALUES (%s, %s, %s, %s, %s)"
    # closing() covers the whole lifetime, so the connection is released even
    # if creating the cursor fails (the original try began after that point).
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        try:
            cur.execute(sql, (student_code, fullname, year, department, photo_bytes))
            user_id = cur.lastrowid
            db.commit()
        except mysql.connector.Error as e:
            db.rollback()
            # 1062 is ER_DUP_ENTRY; presumably student_code is the unique key
            # that can collide here — confirm against schema.
            if e.errno == 1062:
                raise ValueError("รหัสนักเรียนนี้ถูกใช้แล้ว กรุณาเปลี่ยนรหัสใหม่") from e
            raise
    return user_id


def fetch_user_by_code(student_code):
    """Look up a user by student code.

    Returns:
        A tuple (id, student_code, fullname, year, department, photo), or None
        if no user has that code.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "SELECT id, student_code, fullname, year, department, photo FROM users WHERE student_code=%s",
            (student_code,),
        )
        user = cur.fetchone()
    return user


def insert_fingerprint(user_id, sample_no, fmd_bytes):
    """Store one fingerprint sample for a user.

    Args:
        user_id: Owner of the sample.
        sample_no: Sample index, stored in ``fingerprints.sample_no``.
        fmd_bytes: Fingerprint template data, stored in ``fingerprints.fmd``.
    """
    sql = "INSERT INTO fingerprints (user_id, sample_no, fmd) VALUES (%s, %s, %s)"
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(sql, (user_id, sample_no, fmd_bytes))
        db.commit()


def fetch_user_by_id(user_id):
    """Look up a user by primary key.

    Returns:
        A tuple (id, fullname, year, department, photo) — note: no
        student_code — or None if the id does not exist.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "SELECT id, fullname, year, department, photo FROM users WHERE id=%s",
            (user_id,),
        )
        user = cur.fetchone()
    return user


def fetch_fmds_by_user(user_id):
    """Return a user's fingerprint samples.

    Returns:
        A list of (sample_no, fmd) tuples ordered by sample_no; empty if the
        user has none.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            "SELECT sample_no, fmd FROM fingerprints WHERE user_id=%s ORDER BY sample_no",
            (user_id,),
        )
        samples = cur.fetchall()
    return samples


def fetch_all_users():
    """Return every user as a tuple (id, fullname, department, year).

    Note the column order: department precedes year here, unlike the other
    fetch_* helpers.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute("SELECT id, fullname, department, year FROM users")
        users = cur.fetchall()
    return users


def fetch_all_users_with_fmds():
    """Return every user that has fingerprints, grouped with their samples.

    Returns:
        A list of dicts ordered by user id with keys "id", "student_code",
        "fullname", "year", "department", "photo" and "samples". "samples" is a
        list of (sample_no, fmd) tuples ordered by sample_no. Users without
        any fingerprint row are omitted, because the query uses an inner JOIN.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute(
            """
            SELECT u.id, u.student_code, u.fullname, u.year, u.department, u.photo, f.sample_no, f.fmd
            FROM users u
            JOIN fingerprints f ON u.id = f.user_id
            ORDER BY u.id, f.sample_no
            """
        )
        rows = cur.fetchall()

    # Rows arrive sorted by user id; dict insertion order keeps that order.
    users = {}
    for user_id, student_code, fullname, year, department, photo, sample_no, fmd in rows:
        entry = users.get(user_id)
        if entry is None:
            entry = users[user_id] = {
                "id": user_id,
                "student_code": student_code,
                "fullname": fullname,
                "year": year,
                "department": department,
                "photo": photo,
                "samples": [],
            }
        entry["samples"].append((sample_no, fmd))
    return list(users.values())


def delete_user(user_id):
    """Delete a user and their fingerprint rows in one transaction.

    verify_logs rows for the user are not deleted here.

    Raises:
        mysql.connector.Error: If either DELETE or the commit fails; the
            transaction is rolled back first so neither delete is kept.
    """
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        try:
            # Fingerprints first, presumably because they reference users.id.
            cur.execute("DELETE FROM fingerprints WHERE user_id=%s", (user_id,))
            cur.execute("DELETE FROM users WHERE id=%s", (user_id,))
            db.commit()
        except mysql.connector.Error:
            db.rollback()
            raise


def fetch_all_users_with_photo():
    """Return every user as (id, student_code, fullname, year, department, photo)."""
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute("SELECT id, student_code, fullname, year, department, photo FROM users")
        users = cur.fetchall()
    return users


def fetch_user_code(user_id):
    """Return the student code for *user_id*, or "" if the user does not exist."""
    with closing(get_connection()) as db, closing(db.cursor()) as cur:
        cur.execute("SELECT student_code FROM users WHERE id=%s", (user_id,))
        row = cur.fetchone()
    return row[0] if row else ""


def search_closed_sessions(q=""):
    """Return up to 30 closed sessions whose name contains *q*, newest first.

    Args:
        q: Substring to match in the session name. "" or None matches every
            name. LIKE wildcards (% and _) inside q are not escaped.

    Returns:
        A list of dicts with keys id, name, start_time, end_time, ordered by
        end_time descending.
    """
    # `q or ""` keeps None from turning into the pattern "%None%".
    likeq = f"%{q or ''}%"
    with closing(get_connection()) as db, closing(db.cursor(dictionary=True)) as cur:
        # end_session() writes status='CLOSE'. Depending on the column
        # definition, that is stored as 'CLOSE' or, for an ENUM without that
        # member under non-strict SQL mode, as ''. Filtering on <> 'OPEN'
        # matches a closed session in either representation. Filtering on
        # status='' matched only the second.
        cur.execute(
            """
            SELECT id, name, start_time, end_time
            FROM sessions
            WHERE status <> 'OPEN' AND name LIKE %s
            ORDER BY end_time DESC
            LIMIT 30
            """,
            (likeq,),
        )
        rows = cur.fetchall()
    return rows


def get_verified_logs_by_session(session_id):
    """Return up to 1000 successful verify logs for a session, newest first.

    Returns:
        A list of dicts with keys id, verify_time, score (from the log) and
        student_code, fullname, year, department (from the user). The user
        fields are None when the log's user_id has no matching users row,
        because the query uses a LEFT JOIN.
    """
    with closing(get_connection()) as db, closing(db.cursor(dictionary=True)) as cur:
        cur.execute(
            """
            SELECT v.id, v.verify_time, v.score, u.student_code, u.fullname, u.year, u.department
            FROM verify_logs v
            LEFT JOIN users u ON v.user_id = u.id
            WHERE v.session_id = %s AND v.success=1
            ORDER BY v.verify_time DESC
            LIMIT 1000
            """,
            (session_id,),
        )
        logs = cur.fetchall()
    return logs
